Shard MHA. #3115
!build
!build
LGTM
[0, 2, 3],
),
)
T152_matmul = self.ops.sum(T152_local_matmul, [0]) # allreduce
I wonder what would happen currently if we do not decompose the matmul and the allreduce...
The first thing that'll break is that linear will produce a wrong shape. linear, as implemented today, outputs a tensor of rank input_rank + weight_rank - 2 = 5. However, we want the shape to be [d, b, s, e], which is 4D.
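To illustrate the decomposition discussed above, here is a minimal pure-Python sketch of a sharded matmul written as a per-device local matmul followed by a sum over the device axis (standing in for the allreduce). The helper names (matmul, split_columns, split_rows) and the tiny 2D shapes are assumptions made for illustration; this is not nvFuser's API and not the test's actual tensors.

```python
def matmul(a, b):
    """Multiply an (m x k) matrix by a (k x n) matrix, both nested lists."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def split_columns(m, d):
    """Split a matrix into d equal column blocks (shards the inner dimension of the input)."""
    width = len(m[0]) // d
    return [[row[i * width:(i + 1) * width] for row in m] for i in range(d)]


def split_rows(m, d):
    """Split a matrix into d equal row blocks (shards the inner dimension of the weight)."""
    height = len(m) // d
    return [m[i * height:(i + 1) * height] for i in range(d)]


d = 2  # hypothetical number of devices
x = [[1, 2, 3, 4], [5, 6, 7, 8]]                   # [s, e] = [2, 4]
w = [[1, 0, 2], [0, 1, 1], [3, 1, 0], [2, 2, 1]]   # [e, out] = [4, 3]

x_shards = split_columns(x, d)   # d blocks of shape [s, e // d]
w_shards = split_rows(w, d)      # d blocks of shape [e // d, out]

# Local matmul: one partial result per device, i.e. shape [d, s, out].
local = [matmul(xs, ws) for xs, ws in zip(x_shards, w_shards)]

# Sum over the leading device axis; this plays the role of the allreduce.
rows, cols = len(local[0]), len(local[0][0])
reduced = [[sum(local[k][i][j] for k in range(d)) for j in range(cols)] for i in range(rows)]

assert reduced == matmul(x, w)
```

The partial results carry an explicit leading device axis, which is why the reduction is expressed as a sum over axis 0 rather than being folded into the matmul itself.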
T131 = self.ops.permute(T130, dims=[0, 2, 1, 3])
T137 = self.ops.reshape(T117, new_shape=[b, s, h, e // h])
T138 = self.ops.permute(T137, dims=[0, 2, 1, 3])
T123 = self.ops.reshape(T104, new_shape=[d, b, s, h // d, e // h])
IIUC, we go from shape [d, b, s, e//d] to [d, b, s, h//d, e//h]. Nothing illegal about it, but it looks surprising to me, so I just want to make sure.
I double-checked and it looks right. MHA is head-parallel according to Figure 3b in https://arxiv.org/pdf/1909.08053.
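As a quick sanity check of the reshape in question, the sketch below verifies that splitting the last axis of size e // d into (h // d, e // h) preserves the element count whenever d divides h and h divides e. The function names and the example sizes are hypothetical and chosen only for illustration.

```python
def head_parallel_shapes(d, b, s, h, e):
    """Return the shapes before and after the head-parallel reshape.

    d: number of devices, b: batch, s: sequence length,
    h: number of heads, e: embedding size.
    """
    if h % d != 0 or e % h != 0:
        raise ValueError("need d to divide h and h to divide e")
    before = [d, b, s, e // d]
    after = [d, b, s, h // d, e // h]
    return before, after


def numel(shape):
    """Number of elements in a tensor of the given shape."""
    n = 1
    for dim in shape:
        n *= dim
    return n


# Hypothetical sizes: 2 devices, 8 heads, embedding size 64.
before, after = head_parallel_shapes(d=2, b=3, s=5, h=8, e=64)

# Each device holds h // d = 4 heads of size e // h = 8, i.e. 32 = e // d values.
assert numel(before) == numel(after)
```

With these sizes, each device's slice of the embedding axis is divided into whole heads, which is the head-parallel layout described in the reply above.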
!build
As a follow-up to #3045.
For #2199.